{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 7: Neural Network Inference\n",
    "Our library supports neural network inference based on secret sharing. This tutorial presents neural network inference using secure two-party computation (2PC) and secure three-party computation (3PC). As in Tutorial_2, we simulate the parties with multiple threads, and we simulate the trusted third parties that provide auxiliary parameters with local files. The model is shared before prediction, and the data is shared during prediction.\n",
    "\n",
    "For examples of actual usage of the neural network in 2PC, see `./debug/application/neural_network/2pc/neural_network_client.py` and `./debug/application/neural_network/2pc/neural_network_server.py`. For 3PC, see `./debug/application/neural_network/3pc/P0.py`, `./debug/application/neural_network/3pc/P1.py` and `./debug/application/neural_network/3pc/P2.py`.\n",
    "\n",
    "In this tutorial, we use AlexNet as an example. First, train the model using `data/AlexNet/Alexnet_MNIST_train.py`."
   ]
  },
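  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea of sharing a model or an input more concrete, the following is a minimal sketch of two-party additive secret sharing in plain Python. It is an illustration only and not the library's implementation: the 64-bit ring and the helper names `share` and `reconstruct` are assumptions made for this example, and the scheme behind `share_model` and `share_data` may differ.\n",
    "\n",
    "```python\n",
    "import secrets\n",
    "\n",
    "RING_BITS = 64             # assumed ring size, for illustration only\n",
    "MODULUS = 1 << RING_BITS\n",
    "\n",
    "def share(value):\n",
    "    '''Split an integer into two additive shares modulo MODULUS (hypothetical helper).'''\n",
    "    share_0 = secrets.randbelow(MODULUS)\n",
    "    share_1 = (value - share_0) % MODULUS\n",
    "    return share_0, share_1\n",
    "\n",
    "def reconstruct(share_0, share_1):\n",
    "    '''Recombine two additive shares into the original value (hypothetical helper).'''\n",
    "    return (share_0 + share_1) % MODULUS\n",
    "\n",
    "x0, x1 = share(42)\n",
    "y0, y1 = share(100)\n",
    "# Addition is local: each party adds its own shares without communicating\n",
    "z0, z1 = (x0 + y0) % MODULUS, (x1 + y1) % MODULUS\n",
    "assert reconstruct(x0, x1) == 42\n",
    "assert reconstruct(z0, z1) == 142\n",
    "```\n",
    "\n",
    "In this sketch, a single share is a uniformly random ring element and reveals nothing about the value on its own, yet additions can be computed locally on the shares. Multiplications need extra help, which is where the precomputed triples discussed later come in."
   ]
  },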
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.utils.data\n",
    "import torchvision\n",
    "from NssMPC.application.neural_network.utils.converter import load_model\n",
    "from NssMPC.application.neural_network.utils.converter import share_model\n",
    "from NssMPC.application.neural_network.utils.converter import share_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-12T10:02:15.236876Z",
     "start_time": "2024-04-12T10:02:14.573108Z"
    },
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Start Training!\n",
      "[1, 60000] loss:0.4353\n",
      "[2, 60000] loss:0.1440\n",
      "[3, 60000] loss:0.0828\n",
      "[4, 60000] loss:0.1156\n",
      "[5, 60000] loss:0.0364\n",
      "Finished Training\n",
      "time 6.053506851196289\n",
      "Accuracy of the network on the 100 test images:97.49%\n"
     ]
    }
   ],
   "source": [
    "# Train AlexNet by running the training script\n",
    "with open('../data/AlexNet/Alexnet_MNIST_train.py') as f:\n",
    "    exec(f.read())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you encounter a path error, it is because the code requires the model to be downloaded in advance.\n",
    "\n",
    "Next, import the following packages:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-12T10:02:16.818196Z",
     "start_time": "2024-04-12T10:02:15.238489Z"
    }
   },
   "outputs": [],
   "source": [
    "from data.AlexNet.Alexnet import AlexNet\n",
    "import NssMPC.config.configs\n",
    "from NssMPC.application.neural_network.party.neural_network_party import NeuralNetworkCS"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Now, we can create our two parties."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The server acts as the model provider and the client as the data provider. The triples for matrix multiplication must be generated in advance and distributed to both parties. As in Tutorial_2, we simulate this process on the server side.\n",
    "\n",
    "The model provider uses `share_model` to share the model, and the data owner uses `share_data` to share the data. Both functions were imported at the beginning of this tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-12T10:02:17.961811Z",
     "start_time": "2024-04-12T10:02:16.836197Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TCPServer waiting for connection ......\n",
      "TCPServer waiting for connection ......\n",
      "successfully connect to server 127.0.0.1:8000\n",
      "successfully connect to server 127.0.0.1:9000\n",
      "TCPServer successfully connected by :('127.0.0.1', 8200)\n",
      "TCPServer successfully connected by :('127.0.0.1', 9100)\n"
     ]
    }
   ],
   "source": [
    "import threading\n",
    "\n",
    "# Create the server party (model provider)\n",
    "server = NeuralNetworkCS(type='server')\n",
    "\n",
    "def set_server():\n",
    "    # Client-server connection; the explicit connect call is left commented out here\n",
    "    # server.connect(server_server_address, server_client_address, client_server_address, client_client_address)\n",
    "    server.online()\n",
    "\n",
    "# Create the client party (data provider)\n",
    "client = NeuralNetworkCS(type='client')\n",
    "\n",
    "def set_client():\n",
    "    # Client-server connection; the explicit connect call is left commented out here\n",
    "    # client.connect(client_server_address, client_client_address, server_server_address, server_client_address)\n",
    "    client.online()\n",
    "    \n",
    "server_thread = threading.Thread(target=set_server)\n",
    "client_thread = threading.Thread(target=set_client)\n",
    "\n",
    "server_thread.start()\n",
    "client_thread.start()\n",
    "client_thread.join()\n",
    "server_thread.join()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model provider needs to provide and share the model, and the data provider needs to provide the data. Because neural network inference involves matrix multiplication, we simulate one prediction before the actual prediction to generate the required matrix Beaver triples ahead of time. These steps make up the preparation work. Before inference starts, the data provider shares its data. The two parties then load their respective model shares and start inference."
   ]
  },
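  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate why triples have to be prepared before inference, here is a minimal sketch of Beaver-triple multiplication on additively shared scalars. It is a simplified illustration under stated assumptions, not the library's protocol: the 64-bit ring and the helpers `share`, `reconstruct`, `make_triple` and `beaver_multiply` are hypothetical names, and both parties are simulated in a single function. The matrix version replaces the scalar products with matrix products.\n",
    "\n",
    "```python\n",
    "import secrets\n",
    "\n",
    "MODULUS = 1 << 64  # assumed ring size, for illustration only\n",
    "\n",
    "def share(value):\n",
    "    '''Split an integer into two additive shares modulo MODULUS.'''\n",
    "    s0 = secrets.randbelow(MODULUS)\n",
    "    return s0, (value - s0) % MODULUS\n",
    "\n",
    "def reconstruct(s0, s1):\n",
    "    '''Recombine two additive shares.'''\n",
    "    return (s0 + s1) % MODULUS\n",
    "\n",
    "def make_triple():\n",
    "    '''Trusted third party: sample a and b, then share a, b and c = a * b.'''\n",
    "    a = secrets.randbelow(MODULUS)\n",
    "    b = secrets.randbelow(MODULUS)\n",
    "    return share(a), share(b), share(a * b % MODULUS)\n",
    "\n",
    "def beaver_multiply(x_shares, y_shares, triple):\n",
    "    '''Return shares of x * y, consuming one precomputed triple.'''\n",
    "    a_sh, b_sh, c_sh = triple\n",
    "    # Each party masks its shares; the masked values e = x - a and f = y - b are opened\n",
    "    e = reconstruct(*[(x - a) % MODULUS for x, a in zip(x_shares, a_sh)])\n",
    "    f = reconstruct(*[(y - b) % MODULUS for y, b in zip(y_shares, b_sh)])\n",
    "    z_shares = []\n",
    "    for party in (0, 1):\n",
    "        z = (c_sh[party] + e * b_sh[party] + f * a_sh[party]) % MODULUS\n",
    "        if party == 0:\n",
    "            z = (z + e * f) % MODULUS  # the public term e * f is added by one party only\n",
    "        z_shares.append(z)\n",
    "    return tuple(z_shares)\n",
    "\n",
    "z0, z1 = beaver_multiply(share(6), share(7), make_triple())\n",
    "assert reconstruct(z0, z1) == 42\n",
    "```\n",
    "\n",
    "The result is correct because x * y = (e + a)(f + b) = e * f + e * b + f * a + c. Each triple can be used for one multiplication only, which is why the required triples are generated ahead of time."
   ]
  },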
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted result is:  tensor([7])\n",
      "Predicted result is:  tensor([2])\n",
      "Predicted result is:  tensor([1])\n",
      "Predicted result is:  tensor([0])\n",
      "Predicted result is:  tensor([4])\n",
      "Communicator: 127.0.0.1 closed.\n",
      "Communication costs:\n",
      "\tsend rounds: 307\t\tsend bytes: 197.3940315246582 MB.\n",
      "\trecv rounds: 212\t\trecv bytes: 163.3752555847168 MB.\n",
      "Communicator: 127.0.0.1 closed.\n",
      "Communication costs:\n",
      "\tsend rounds: 212\t\tsend bytes: 163.3752555847168 MB.\n",
      "\trecv rounds: 307\t\trecv bytes: 197.3940315246582 MB.\n"
     ]
    }
   ],
   "source": [
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import Subset\n",
    "from NssMPC.common.utils import get_time\n",
    "from NssMPC.config import NN_path, DEVICE\n",
    "\n",
    "def server_predict():\n",
    "    # Load the pre-trained plaintext model\n",
    "    net = AlexNet()\n",
    "    model_path = NN_path + '/AlexNet_MNIST.pkl'\n",
    "    net.load_state_dict(torch.load(model_path))\n",
    "\n",
    "    # Split the parameters into shares and send the client its share\n",
    "    shared_param, shared_param_for_other = share_model(net)\n",
    "    server.send(shared_param_for_other)\n",
    "\n",
    "    # Simulate one prediction to prepare the matrix Beaver triples;\n",
    "    # num is used below as the number of inference rounds\n",
    "    num = server.dummy_model(net)\n",
    "    net = load_model(net, shared_param)\n",
    "    while num:\n",
    "        # Receive the client's share of the input and run one inference\n",
    "        shared_data = server.receive()\n",
    "        server.inference(net, shared_data)\n",
    "        num -= 1\n",
    "    # close party after inference\n",
    "    server.close()\n",
    "\n",
    "def client_predict():\n",
    "    net = AlexNet()\n",
    "    # Use the first five MNIST test images as the client's data\n",
    "    transform1 = transforms.Compose([transforms.ToTensor()])\n",
    "    test_set = torchvision.datasets.MNIST(root=NN_path, train=False, download=True, transform=transform1)\n",
    "    indices = list(range(5))\n",
    "    subset_data = Subset(test_set, indices)\n",
    "    test_loader = torch.utils.data.DataLoader(subset_data, batch_size=1, shuffle=False, num_workers=0)\n",
    "\n",
    "    # Receive this party's share of the model parameters from the server\n",
    "    shared_param = client.receive()\n",
    "    num = client.dummy_model(test_loader)\n",
    "    net = load_model(net, shared_param)\n",
    "\n",
    "    for data in test_loader:\n",
    "        images, _ = data  # labels are not used during inference\n",
    "        images = images.to(DEVICE)\n",
    "\n",
    "        # Share the input and send the server its share\n",
    "        shared_data, shared_data_for_other = share_data(images)\n",
    "        client.send(shared_data_for_other)\n",
    "\n",
    "        res = client.inference(net, shared_data)\n",
    "\n",
    "        # The predicted class is the index of the largest output\n",
    "        _, predicted = torch.max(res, 1)\n",
    "\n",
    "        print(\"Predicted result is: \", predicted)\n",
    "\n",
    "    # close party after inference\n",
    "    client.close()\n",
    "    \n",
    "server_thread = threading.Thread(target=server_predict)\n",
    "client_thread = threading.Thread(target=client_predict)\n",
    "\n",
    "server_thread.start()\n",
    "client_thread.start()\n",
    "client_thread.join()\n",
    "server_thread.join()"
   ]
  },
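  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the code above, the server (P0) instantiates an AlexNet model and loads the weights of the pre-trained model. The client (P1) also instantiates an AlexNet model, but it loads only the parameter shares it receives from the server."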
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The prediction results are shown above. The core statements our library uses for neural network prediction are `server.inference` and `client.inference`. If you wish to perform additional operations on the prediction results, you can process them according to your specific requirements.\n",
    "\n",
    "In [data/neural_network/AlexNet/Alexnet.py](https://gitee.com/xdnss/mpctensorlib/tree/master/data/neural_network/AlexNet/Alexnet.py) and [data/neural_network/ResNet/ResNet.py](https://gitee.com/xdnss/mpctensorlib/tree/master/data/neural_network/ResNet/ResNet.py), we provide the training code and pre-trained models for AlexNet and ResNet50. You can use them to train models according to your specific requirements and then perform inference with the trained models."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
